{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c2acbdea-bdce-43f8-86d4-fb602f67beba",
   "metadata": {},
   "source": [
    "# Example 07 - Sig53 with YOLOv8 Classifier\n",
    "This notebook showcases using the Sig53 dataset to train a YOLOv8 classification model.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2addcb1-6b3d-4484-a9b9-20d161840d95",
   "metadata": {},
   "source": [
    "## Import Libraries\n",
    "We will import all the usual libraries, in addition to Ultralytics. You can install Ultralytics with:\n",
    "```bash\n",
    "pip install ultralytics\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb874df6-3e3b-4fc0-a446-25f90923ec9e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Packages for Training\n",
    "from torchsig.utils.yolo_classify import *\n",
    "from torchsig.utils.classify_transforms import real_imag_vstacked_cwt_image, complex_iq_to_heatmap\n",
    "import yaml\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "810529c8-adf4-40d1-b735-455a72df2902",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Packages for testing/inference\n",
    "import numpy as np\n",
    "from torchsig.datasets.modulations import ModulationsDataset\n",
    "from torchsig.transforms.target_transforms import DescToFamilyName\n",
    "from torchsig.transforms.transforms import Spectrogram, SpectrogramImage, Normalize, Compose, Identity\n",
    "from ultralytics import YOLO\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c63eff7e-5dba-459b-88ba-d2e855c8cd27",
   "metadata": {},
   "source": [
    "## Prepare YOLO Classification Trainer and Model\n",
    "Datasets are generated on the fly in a format compatible with Ultralytics YOLO. See [Ultralytics: Train Custom Data - Organize Directories](https://docs.ultralytics.com/yolov5/tutorials/train_custom_data/#23-organize-directories) to learn more.\n",
    "\n",
    "Additionally, we create a YAML file for dataset configuration. See `classify.yaml` in the TorchSig examples directory.\n",
    "\n",
    "Download the desired YOLO model from [Ultralytics Models](https://docs.ultralytics.com/models/). We will use YOLOv8, specifically `yolov8n-cls.pt`.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1ddb6903-a7ec-4f3e-8555-8132cbfbc4fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "config_path = 'classify.yaml'\n",
    "with open(config_path, 'r') as file:\n",
    "    config = yaml.safe_load(file)\n",
    "\n",
    "overrides = config['overrides']"
   ]
  },
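  {
   "cell_type": "markdown",
   "id": "7a1e4c52-0b9d-4f3a-8c61-2d5e9f0a1b34",
   "metadata": {},
   "source": [
    "As a rough guide to what the loaded configuration needs to provide, the sketch below mirrors `config` as a Python dictionary and checks for the keys that later cells in this notebook read (`overrides`, `family`, `names`, `level`, `nc`, `include_snr`). All values shown are illustrative assumptions rather than the contents of the shipped `classify.yaml`, and `missing_keys` is a hypothetical helper, not part of TorchSig or Ultralytics.\n",
    "\n",
    "```python\n",
    "# Keys that later cells in this notebook read from the loaded config.\n",
    "REQUIRED_KEYS = ('overrides', 'family', 'names', 'level', 'nc', 'include_snr')\n",
    "\n",
    "# Illustrative stand-in for yaml.safe_load(...); every value is an assumption.\n",
    "example_config = {\n",
    "    'overrides': {'model': 'yolov8n-cls.pt', 'epochs': 100,\n",
    "                  'data': 'classify.yaml', 'device': 0, 'imgsz': 64},\n",
    "    'family': False,                  # map class descriptions to family names?\n",
    "    'names': {0: 'bpsk', 1: 'qpsk'},  # class index -> class name\n",
    "    'level': 0,                       # assumed dataset difficulty level\n",
    "    'nc': 2,                          # number of classes\n",
    "    'include_snr': False,             # whether labels also carry an SNR value\n",
    "}\n",
    "\n",
    "def missing_keys(config, required=REQUIRED_KEYS):\n",
    "    \"\"\"Return the required keys that are absent from config.\"\"\"\n",
    "    return [key for key in required if key not in config]\n",
    "\n",
    "print(missing_keys(example_config))  # []\n",
    "```\n",
    "\n",
    "Running a check like this on the real `config` before building the trainer can surface a missing key early, instead of as a `KeyError` in a later cell."
   ]
  },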
  {
   "cell_type": "markdown",
   "id": "3acf456a-4d02-4ff8-b44c-f0b0c3014723",
   "metadata": {},
   "source": [
    "### Explanation of the `overrides` Dictionary\n",
    "\n",
    "The `overrides` dictionary customizes the Ultralytics YOLO trainer by supplying values that replace its default configuration. In this notebook the dictionary is loaded from `classify.yaml`, but you can also define or modify it directly in the notebook.\n",
    "\n",
    "Example:\n",
    "\n",
    "```python\n",
    "overrides = {'model': 'yolov8n-cls.pt', 'epochs': 100, 'data': 'classify.yaml', 'device': 0, 'imgsz': 64}\n",
    "```\n",
    "A YAML file is required for training. See `classify.yaml` in the examples directory; it contains the path to your TorchSig data.\n",
    "\n",
    "### Explanation of the `image_transform` Function\n",
    "`YoloClassifyTrainer` accepts any transform that takes complex I/Q samples and outputs an image for training. Some example transforms can be found in `torchsig.utils.classify_transforms`. If no transform is passed, it defaults to spectrogram images. Make sure to update `overrides` so that `imgsz` matches the size of the images your transform produces."
   ]
  },
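  {
   "cell_type": "markdown",
   "id": "3f8b2d17-6c4e-4a90-b5d2-9e1c7a0f4d68",
   "metadata": {},
   "source": [
    "One way to adjust settings from within the notebook, sketched here under the assumption that `overrides` is a plain dictionary like the example above, is to merge a second dictionary over the values loaded from YAML. The names `yaml_overrides` and `notebook_overrides` are illustrative only.\n",
    "\n",
    "```python\n",
    "# Stand-in for config['overrides'] as loaded from classify.yaml.\n",
    "yaml_overrides = {'model': 'yolov8n-cls.pt', 'epochs': 100,\n",
    "                  'data': 'classify.yaml', 'device': 0, 'imgsz': 64}\n",
    "\n",
    "# Values to change for this session only; later entries win in the merge.\n",
    "notebook_overrides = {'epochs': 5, 'imgsz': 128}\n",
    "\n",
    "overrides = {**yaml_overrides, **notebook_overrides}\n",
    "print(overrides['epochs'], overrides['imgsz'], overrides['model'])  # 5 128 yolov8n-cls.pt\n",
    "```\n",
    "\n",
    "Because the test dataset cell later in this notebook derives its spectrogram size and sample count from `overrides['imgsz']`, a change to `imgsz` made this way carries through to it."
   ]
  },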
  {
   "cell_type": "markdown",
   "id": "e9d94dba-1b87-4839-8c34-a2b449ede80d",
   "metadata": {},
   "source": [
    "### Build YoloClassifyTrainer\n",
    "This will instantiate the YOLO classification trainer with overrides specified above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f134e461-456a-4be7-9c91-ce26c5c4f850",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = YoloClassifyTrainer(overrides=overrides, image_transform=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d1d1cb68-5452-498b-bf5b-a7b96478429c",
   "metadata": {},
   "source": [
    "### Begin Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51d670a0-54e3-4a37-9649-740412e302ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "454e71d2-fc9c-4f60-9545-bb0997fb5334",
   "metadata": {},
   "source": [
    "### Instantiate Test Dataset\n",
    "\n",
    "This cell uses TorchSig's `ModulationsDataset` to generate a narrowband classification dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce307e25-99cb-4c72-9910-3d18d89d9c5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Determine whether to map descriptions to family names\n",
    "if config['family']:\n",
    "    target_transform = Compose([DescToFamilyName()])\n",
    "else:\n",
    "    target_transform = None\n",
    "\n",
    "transform = Compose([\n",
    "    Spectrogram(nperseg=overrides['imgsz'], noverlap=0, nfft=overrides['imgsz'], mode='psd'),\n",
    "    Normalize(norm=np.inf, flatten=True),\n",
    "    SpectrogramImage(), \n",
    "    ])\n",
    "\n",
    "class_list = [item[1] for item in config['names'].items()]\n",
    "\n",
    "dataset = ModulationsDataset(\n",
    "    classes=class_list,\n",
    "    use_class_idx=False,\n",
    "    level=config['level'],\n",
    "    num_iq_samples=overrides['imgsz']**2,\n",
    "    num_samples=int(config['nc'] * 10),\n",
    "    include_snr=config['include_snr'],\n",
    "    transform=transform,\n",
    "    target_transform=target_transform\n",
    ")\n",
    "\n",
    "# Retrieve a sample and print out information\n",
    "idx = np.random.randint(len(dataset))\n",
    "data, label = dataset[idx]\n",
    "print(\"Dataset length: {}\".format(len(dataset)))\n",
    "print(\"Data shape: {}\".format(data.shape))\n",
    "\n",
    "samples = []\n",
    "labels = []\n",
    "for i in range(10):\n",
    "    idx = np.random.randint(len(dataset))\n",
    "    sample, label = dataset[idx]\n",
    "    samples.append(sample)\n",
    "    labels.append(label)"
   ]
  },
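  {
   "cell_type": "markdown",
   "id": "b4d09e6a-2f71-4c35-a8e0-5c3d7f1b9a26",
   "metadata": {},
   "source": [
    "The choice of `num_iq_samples=overrides['imgsz']**2` pairs with the spectrogram settings: with `nperseg` and `nfft` both equal to `imgsz` and `noverlap=0`, the signal splits into `imgsz` non-overlapping segments, each of which should yield `imgsz` frequency bins for complex input, giving a square image. The arithmetic is sketched below; `spectrogram_shape` is a hypothetical helper that only counts segments and bins and does not call TorchSig.\n",
    "\n",
    "```python\n",
    "def spectrogram_shape(num_iq_samples, nperseg, noverlap, nfft):\n",
    "    \"\"\"Return (frequency_bins, time_segments) for a two-sided spectrogram.\"\"\"\n",
    "    step = nperseg - noverlap  # samples advanced per segment\n",
    "    time_segments = (num_iq_samples - noverlap) // step\n",
    "    return nfft, time_segments\n",
    "\n",
    "imgsz = 64\n",
    "shape = spectrogram_shape(num_iq_samples=imgsz ** 2, nperseg=imgsz,\n",
    "                          noverlap=0, nfft=imgsz)\n",
    "print(shape)  # (64, 64)\n",
    "```"
   ]
  },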
  {
   "cell_type": "markdown",
   "id": "caab72b6-df31-4a08-88c7-b37400aec5d2",
   "metadata": {},
   "source": [
    "### Predictions / Inference\n",
    "The following cells show how to load the `best.pt` weights from your training run for prediction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5644b3e9-a7c8-4865-8d16-f47b54cf7606",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_path = 'YOUR_PROJECT_NAME/YOUR_CLASSIFY_EXPERIMENT/weights/best.pt'  # replace with your path to 'best.pt'\n",
    "model = YOLO(model_path)  # the model will remember the configuration from training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "295ffbe6-4f05-4c01-bb49-1652869b1333",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = model.predict(samples)"
   ]
  },
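  {
   "cell_type": "markdown",
   "id": "e2c6a5f0-9d38-4b17-8a4c-1f7e0b3d5c92",
   "metadata": {},
   "source": [
    "Once predictions are available, one way to summarize them over many samples is a top-1 accuracy computed from the actual and predicted class names. This is a minimal sketch: `top1_accuracy` is a hypothetical helper, and the two lists stand in for `labels` and for `result.names[result.probs.top1]` collected over `results`. It assumes each label is a plain class name, which may not hold if `include_snr` is enabled.\n",
    "\n",
    "```python\n",
    "def top1_accuracy(actual, predicted):\n",
    "    \"\"\"Return the fraction of positions where the predicted name matches.\"\"\"\n",
    "    if len(actual) != len(predicted):\n",
    "        raise ValueError('actual and predicted must have the same length')\n",
    "    if not actual:\n",
    "        return 0.0\n",
    "    matches = sum(a == p for a, p in zip(actual, predicted))\n",
    "    return matches / len(actual)\n",
    "\n",
    "# Illustrative stand-ins; real values would come from labels and results.\n",
    "actual_names = ['bpsk', 'qpsk', '8psk', 'bpsk']\n",
    "predicted_names = ['bpsk', 'qpsk', 'qpsk', 'bpsk']\n",
    "print(top1_accuracy(actual_names, predicted_names))  # 0.75\n",
    "```"
   ]
  },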
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "135974bd-c0b2-478f-895b-96d487b9d33e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Process results list\n",
    "for y, result in enumerate(results):\n",
    "    probs = result.probs  # Probs object for classification outputs\n",
    "    print(f'Actual Label -> {labels[y]}')\n",
    "    print(f'Top 1 Prediction -> {result.names[probs.top1]}, {probs.top1conf}')\n",
    "    # Look up names for however many top classes are returned (at most five)\n",
    "    top5_names = [result.names[i] for i in probs.top5]\n",
    "    print(f'Top 5 Predictions -> {top5_names}, {probs.top5conf.cpu().numpy().tolist()}')\n",
    "\n",
    "    img = Image.fromarray(result.orig_img)\n",
    "    img.show()  # display to screen\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
